{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5628e398",
   "metadata": {},
   "source": [
    "# The `Truncation` notebook  <a id=\"Truncation\"></a>[<font size=1>(back to `Main.ipynb`)</font>](./Main.ipynb)\n",
    "\n",
    "This notebook gathers all functions related to the construction of the truncated version of the Aiyagari model (i.e., for a given fiscal policy). \n",
    "\n",
    "The intuition of the truncation method can be summed up in two steps.\n",
    "\n",
    "1. Aggregating agents with the same recent idiosyncratic history (called *truncated history*).\n",
    "2. Expressing the model in terms of these groups of agents instead of individual agents. \n",
    "\n",
    "This generates the so-called **truncated model**, whose main advantage will be to feature a *finite state space*. \n",
    "\n",
    "The rest of this notebook consists of building the truncated model from the Aiyagari model. In terms of Julia objects, we transform an `AiyagariSolution` (i.e., the solution of the Aiyagari model) and an `Economy` (i.e., the parameters characterizing the economy) into a `TruncatedModel` object. The definitions of these objects can be found in `Structures.ipynb`.\n",
    "\n",
    "### The notebook is organized as follows:\n",
    "1. [Constructing the set of truncated histories](#constructing-truncated-hist);\n",
    "2. [Characterizing the distribution of truncated histories](#transition-stat-dist);\n",
    "3. [Constructing the truncated model](#truncated).\n",
    "\n",
    "In each step, we briefly recall the formal aspects and their connection with the implementation. In particular, we specify how the mathematical objects are represented computationally."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d09db785",
   "metadata": {},
   "source": [
    "# Constructing the set of truncated histories   <a id=\"constructing-truncated-hist\"></a>[<font size=1>(back to menu)</font>](#Truncation)\n",
    "\n",
    "We construct:\n",
    "1. [*uniform* truncated histories](#uniform-truncated-hist) and\n",
    "2. [*refined* truncated histories](#refined-truncated-hist).\n",
    "\n",
    "\n",
    "## Constructing *uniform* truncated histories <a id=\"uniform-truncated-hist\"></a>[<font size=1>(back to truncated histories)</font>](#constructing-truncated-hist)\n",
    "\n",
    "[Presentation](#pres-uniform-truncated-hist) and [implementation](#imp-uniform-truncated-hist).\n",
    "\n",
    "### Presentation <a id=\"pres-uniform-truncated-hist\"></a>[<font size=1>(back to truncated histories)</font>](#uniform-truncated-hist)\n",
    "\n",
    "A truncated history is a vector of given length $N$ gathering the idiosyncratic shocks of the last $N$ periods. For instance, if the idiosyncratic shock is denoted by $y$ and takes values in a set $\\mathcal Y$, then $h=(\\tilde{y}_{-N+1},\\ldots,\\tilde{y}_{-1},\\tilde{y}_{0})\\in \\mathcal{Y}^N$ is a truncated history of length $N$. In period $t$, an agent $i$ with idiosyncratic history $\\{y_{i,0},\\ldots,y_{i,t}\\}$ has truncated history $h$ if their history over the last $N$ periods is $h$, i.e., if $(y_{i,t-N+1},\\ldots,y_{i,t})=(\\tilde{y}_{-N+1},\\ldots,\\tilde{y}_{-1},\\tilde{y}_{0})$.\n",
    "\n",
    "These truncated histories are called *uniform*, because they all have the same length $N$.\n",
    "\n",
    "The set of uniform truncated histories is thus simply $\\mathcal{Y}^N$. Our implementation features two twists. \n",
    "1. Truncated histories will be vectors, but ordered from the current state (first element) to the state $N-1$ periods ago (last element).\n",
    "2. The vector will not gather the idiosyncratic states themselves, but their corresponding index in the vector representing the set $\\mathcal Y$.\n",
    "\n",
    "Let us give an example. Assume that $\\mathcal Y = \\{y_1,y_2,y_3\\}$, which is represented by the vector `ys = [y1,y2,y3]` (a field of the `struct Economy`, see `Structures.ipynb`). Consider an agent whose idiosyncratic history is $(\\ldots,y_3,y_2,y_2,y_1,y_1)$. This should be read as follows: the agent is currently in state $y_1$, was in state $y_1$ one period ago, in state $y_2$ two and three periods ago, and in state $y_3$ four periods ago.\n",
    "\n",
    "Assuming a truncation length of $N=5$ periods, this history is represented by the `Vector{Int}` equal to `[1,1,2,2,3]` (i.e., reversed order, and indices in `ys` instead of values)."
   ]
  },
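  {
   "cell_type": "markdown",
   "id": "sketch-truncate-history",
   "metadata": {},
   "source": [
    "The convention above can be sketched with a small helper. This is a hypothetical function (`truncate_history` is not defined elsewhere in the project), assuming that the full idiosyncratic history is given as a vector of state indices in chronological order (oldest first):\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper (not part of the notebook): convert a full history of\n",
    "# state indices, stored oldest first, into the truncated-history convention\n",
    "# used here (last N periods, current state first).\n",
    "function truncate_history(chron::Vector{Int}, N::Int)::Vector{Int}\n",
    "    @assert 1 <= N <= length(chron)\n",
    "    return reverse(chron[end-N+1:end])  # keep the last N periods, then reverse\n",
    "end\n",
    "\n",
    "truncate_history([3, 2, 2, 1, 1], 5)  # [1, 1, 2, 2, 3], the example above\n",
    "truncate_history([3, 2, 2, 1, 1], 3)  # [1, 1, 2]\n",
    "```"
   ]
  },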
  {
   "cell_type": "markdown",
   "id": "f596deb3",
   "metadata": {},
   "source": [
    "### Implementation <a id=\"imp-uniform-truncated-hist\"></a>[<font size=1>(back to truncated histories)</font>](#uniform-truncated-hist)\n",
    "\n",
    "The function: \n",
    "\n",
    ">`generateHistories(N::I,ny::I)::Vector{Vector{I}} where I<:Int` \n",
    "\n",
    "takes as inputs:\n",
    "\n",
    "* a truncation length `N` (same meaning as $N$);\n",
    "* a number of idiosyncratic states `ny` (same meaning as $Card(\\mathcal Y)$).\n",
    "\n",
    "Both inputs must be at least 1 (this is checked with `@assert`).\n",
    "\n",
    "The function returns, as a `Vector`, all truncated histories of length `N`. Each truncated history is itself a `Vector` (of length `N`), hence the output type `Vector{Vector{I}}` (vector of vectors). By construction, the output has cardinality $Card(\\mathcal Y)^N$, i.e., `ny^N`.\n",
    "\n",
    "The implementation is functional in spirit and relies on the functions `map` and `mapreduce`. Performance is not critical, since the function is only called once for the model truncation.\n",
    "\n",
    "***Remark.*** The order of the set of truncated histories does not matter. What matters is that it is always iterated in the same order.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6ff9c24f",
   "metadata": {},
   "outputs": [],
   "source": [
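    "# Uniform truncated histories of length N over ny states, built recursively:\n",
    "# each history of length N-1 is extended by appending one older state n in 1:ny.\n",
    "function generateHistories(N::I,ny::I)::Vector{Vector{I}} where I<:Int\n",
    "    @assert N≥1 && ny≥1\n",
    "    if N==1\n",
    "        return map(n->[n],1:ny)  # length-1 histories: the current state only\n",
    "    else\n",
    "        return mapreduce(h -> map(n->push!(copy(h),n),1:ny),vcat,generateHistories(N-1,ny))\n",
    "    end\n",
    "end;"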
   ]
  },
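  {
   "cell_type": "markdown",
   "id": "sketch-histories-iter",
   "metadata": {},
   "source": [
    "As a cross-check on the recursive construction, the same set can be enumerated non-recursively with `Iterators.product` from Julia's `Base`. The sketch below is an illustrative alternative (the name `generateHistories_iter` is hypothetical). Since `Iterators.product` varies its first factor fastest, each tuple is reversed to reproduce the order of `generateHistories`:\n",
    "\n",
    "```julia\n",
    "# Hypothetical non-recursive alternative to generateHistories(N, ny).\n",
    "# Iterators.product varies its first factor fastest, so each tuple is\n",
    "# reversed to obtain the lexicographic order of generateHistories.\n",
    "function generateHistories_iter(N::Int, ny::Int)::Vector{Vector{Int}}\n",
    "    @assert N >= 1 && ny >= 1\n",
    "    grid = Iterators.product(ntuple(_ -> 1:ny, N)...)  # all N-tuples over 1:ny\n",
    "    return [reverse(collect(t)) for t in vec(collect(grid))]\n",
    "end\n",
    "\n",
    "generateHistories_iter(2, 2)                 # [[1, 1], [1, 2], [2, 1], [2, 2]]\n",
    "length(generateHistories_iter(3, 4)) == 4^3  # true: ny^N histories\n",
    "```\n",
    "\n",
    "Either way, the output is the full Cartesian product $\\mathcal Y^N$ listed in a fixed order."
   ]
  },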
  {
   "cell_type": "markdown",
   "id": "f6234621",
   "metadata": {},
   "source": [
    "#### Two examples\n",
    "\n",
    "We check the output for: \n",
    "\n",
    "* `N=2` and `ny=2`\n",
    "* `N=3` and `ny=4`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4a790a0c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4-element Vector{Vector{Int64}}:\n",
       " [1, 1]\n",
       " [1, 2]\n",
       " [2, 1]\n",
       " [2, 2]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generateHistories(2,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ef95a9bf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "64-element Vector{Vector{Int64}}:\n",
       " [1, 1, 1]\n",
       " [1, 1, 2]\n",
       " [1, 1, 3]\n",
       " [1, 1, 4]\n",
       " [1, 2, 1]\n",
       " [1, 2, 2]\n",
       " [1, 2, 3]\n",
       " [1, 2, 4]\n",
       " [1, 3, 1]\n",
       " [1, 3, 2]\n",
       " [1, 3, 3]\n",
       " [1, 3, 4]\n",
       " [1, 4, 1]\n",
       " ⋮\n",
       " [4, 2, 1]\n",
       " [4, 2, 2]\n",
       " [4, 2, 3]\n",
       " [4, 2, 4]\n",
       " [4, 3, 1]\n",
       " [4, 3, 2]\n",
       " [4, 3, 3]\n",
       " [4, 3, 4]\n",
       " [4, 4, 1]\n",
       " [4, 4, 2]\n",
       " [4, 4, 3]\n",
       " [4, 4, 4]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generateHistories(3,4)"
   ]
  },
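  {
   "cell_type": "markdown",
   "id": "sketch-history-index",
   "metadata": {},
   "source": [
    "Because this order is lexicographic (the current state varies slowest), the position of a uniform history in the output of `generateHistories(N,ny)` can be computed directly by reading the history as a base-`ny` number. The helper below is a hypothetical illustration, not a function used by the notebook. It reproduces positions visible in the two outputs above:\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper (not part of the notebook): position of a uniform truncated\n",
    "# history h in generateHistories(N, ny), reading h as a base-ny number with\n",
    "# digits h[k]-1 (h[1] is the most significant digit), evaluated by Horner's rule.\n",
    "function history_index(h::Vector{Int}, ny::Int)::Int\n",
    "    idx = 0\n",
    "    for k in h\n",
    "        idx = idx * ny + (k - 1)\n",
    "    end\n",
    "    return idx + 1   # Julia arrays are 1-based\n",
    "end\n",
    "\n",
    "history_index([2, 1], 2)      # 3\n",
    "history_index([1, 4, 1], 4)   # 13\n",
    "history_index([4, 4, 4], 4)   # 64\n",
    "```"
   ]
  },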
  {
   "cell_type": "markdown",
   "id": "f153c20f",
   "metadata": {},
   "source": [
    "## Constructing *refined* truncated histories <a id=\"refined-truncated-hist\"></a>[<font size=1>(back to truncated histories)</font>](#constructing-truncated-hist)\n",
    "\n",
    "[Presentation](#pres-refined-truncated-hist) and [implementation](#imp-refined-truncated-hist).\n",
    "\n",
    "### Presentation <a id=\"pres-refined-truncated-hist\"></a>[<font size=1>(back to truncated histories)</font>](#refined-truncated-hist)\n",
    "\n",
    "Although simple, uniform truncated histories have two drawbacks: their number grows exponentially with $N$, and they include some histories that are very unlikely to be experienced by agents and hence have very small sizes.\n",
    "\n",
    "To solve this issue, an option (see LeGrand and Ragot, AES, 2022) is to consider different truncation lengths for different histories, which will be called *refined* truncated histories. Histories that are more likely to be experienced (i.e., larger ones) can be “refined”, meaning that they can be replaced by a set of histories with longer truncation lengths. \n",
    "\n",
    "For instance, the truncated history $(y_{1},y_{1})$ ($N=2$) can be refined into $\\{(y,y_{1},y_{1}):y\\in\\mathcal{Y}\\}$, where the group of agents who have been in productivity $y_{1}$ for two consecutive periods is split into $Card(\\mathcal{Y})$ truncated histories. \n",
    "\n",
    "The function we implement takes advantage of the fact that idiosyncratic states are persistent, so that the largest histories are typically those in which the agent remains in the same state for $N$ periods (e.g., $(y_{1},\\ldots ,y_{1})$ for $N$ periods). The function thus only refines these constant histories. Besides being motivated by this property of the data, this construction yields a well-defined partition of the set of idiosyncratic histories (which is not the case with an arbitrary refinement).\n",
    "\n",
    "The set of refined truncated histories is characterized by:\n",
    "* the common truncation length $N$,\n",
    "* the refined truncation length $N_y^r$ for each idiosyncratic state $y\\in\\mathcal Y$.\n",
    "\n",
    "\n",
    "The implementation uses the same representations as in `generateHistories`: the set of histories is a vector of vectors and each truncated history vector is ordered from the most recent to the oldest state. The vector of refined truncation lengths is a vector `[N1,N2,...,Nny]` where, for example, `N2` is the refined truncation length of the idiosyncratic realization $y_2$.\n",
    "\n",
    "\n",
    "Let us give an example. We assume $\\mathcal Y = \\{y_1,y_2\\}$, with $N=2$ and $N^r = [4,3]$. This implies that all histories have a minimal length of 2 and that histories where the agent remains in state $y_1$ (resp. $y_2$) are refined up to length 4 (resp. 3). The refinement can be thought of in 3 steps:\n",
    "1. We start with the set of uniform truncated histories ($N=2$), which is: `[[1, 1], [1, 2], [2, 1], [2, 2]]`.\n",
    "2. We then refine `[1, 1]` and `[2, 2]` once and obtain `[[1, 1, 1], [1, 1, 2], [1, 2], [2, 1], [2, 2, 1], [2, 2, 2]]`. Since $N^r_2=3$, the refinement of state 2 is complete.\n",
    "3. We refine `[1, 1, 1]` once more and we obtain `[[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2], [1, 2], [2, 1], [2, 2, 1], [2, 2, 2]]`. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a557b66c",
   "metadata": {},
   "source": [
    "### Implementation <a id=\"imp-refined-truncated-hist\"></a>[<font size=1>(back to truncated histories)</font>](#refined-truncated-hist)\n",
    "\n",
    "The implementation uses the same representations as in `generateHistories`: the set of histories is a vector of vectors and each truncated history vector is ordered from the most recent to the oldest state. We take advantage of multiple dispatch to define a refined version of `generateHistories` alongside the uniform one. \n",
    "\n",
    "The function `generateHistories(N::I,refiNs::Vector{I})::Vector{Vector{I}} where I<:Int` takes as inputs:\n",
    "\n",
    "* a truncation length `N` (same meaning as $N$);\n",
    "* a vector `refiNs` of refined truncation lengths, one for each idiosyncratic state (same meaning as $(N^r_1,\\ldots,N^r_{Card(\\mathcal Y)})$): `refiNs[k]` gives the truncation length of the `k`th idiosyncratic state.\n",
    "\n",
    "We request `N ≥ 1` and `length(refiNs) ≥ 1`. If `refiNs[k] ≤ N` for some `k`, the constant history of state `k` is not refined and keeps length `N`. A three-argument method `generateHistories(N,refiNs,ny)` is also defined; it only uses the first `ny` elements of `refiNs` (with `ny ≤ length(refiNs)`).\n",
    "\n",
    "The function returns a vector of vectors of type `Vector{Vector{I}}` as in the uniform case, except that the vectors no longer all have the same length. Again, the implementation is functional and performance is not critical. \n",
    "\n",
    "The function works as follows:\n",
    "* it builds the set of uniform truncated histories `generateHistories(N,ny)` where `ny = length(refiNs)`;\n",
    "* it then refines histories using the function `refine_hist(refiNs::Vector{I},hs::Vector{Vector{I}},ny::I)::Vector{Vector{I}} where I<:Int`, called with `refiNs .- N`, i.e., the number of refinement steps remaining for each state. This function is recursive: it decrements this vector by one at each step and stops when all its elements are non-positive;\n",
    "* the refinement of one individual history is done by the function `refine_one_hist(h::Vector{I},refiNs::Vector{I},ny::I)::Vector{Vector{I}} where I<:Int`. If `h` is a history with a constant state and this state still needs to be refined (the corresponding element of `refiNs` is strictly positive), then `h` is replaced by the `ny` histories obtained by appending each element of `1:ny` to `h` (i.e., one more period in the past).\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a23f188f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "generateHistories (generic function with 3 methods)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
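    "# Refine one history h: if h is constant (the agent stayed in the same state in all\n",
    "# recorded periods) and that state still needs refinement, replace h by the ny\n",
    "# histories obtained by appending one more (older) period; otherwise keep h.\n",
    "function refine_one_hist(h::Vector{I},refiNs::Vector{I},ny::I)::Vector{Vector{I}} where I<:Int\n",
    "    if all(h.==h[1]) && refiNs[h[1]] > zero(I)\n",
    "        return map(n->push!(copy(h),n),1:ny)\n",
    "    else\n",
    "        return [copy(h)]\n",
    "    end\n",
    "end\n",
    "\n",
    "# Refine the whole set hs: refiNs[k] is the number of refinement steps still to be\n",
    "# performed for state k; it is decremented at each recursion until all entries are ≤ 0.\n",
    "function refine_hist(refiNs::Vector{I},hs::Vector{Vector{I}},ny::I)::Vector{Vector{I}} where I<:Int\n",
    "    if all(refiNs .<= zero(I))\n",
    "        return hs\n",
    "    else\n",
    "        return refine_hist(refiNs.-1,mapreduce(h -> refine_one_hist(h,refiNs,ny),vcat,hs),ny)\n",
    "    end\n",
    "end\n",
    "\n",
    "# Refined truncated histories using the first ny states of refiNs: start from the\n",
    "# uniform histories of length N and refine state k up to length refiNs[k].\n",
    "function generateHistories(N::I,refiNs::Vector{I},ny::I)::Vector{Vector{I}} where I<:Int\n",
    "    @assert ny ≥ one(I) && N ≥ one(I) && ny ≤ length(refiNs)\n",
    "    return refine_hist(refiNs[1:ny] .- N, generateHistories(N,ny), ny)\n",
    "end\n",
    "\n",
    "# Convenience method: use all states, i.e., ny = length(refiNs).\n",
    "function generateHistories(N::I,refiNs::Vector{I})::Vector{Vector{I}} where I<:Int\n",
    "    return generateHistories(N,refiNs,length(refiNs))\n",
    "end"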
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f24fb8a2",
   "metadata": {},
   "source": [
    "#### Example\n",
    "\n",
    "We consider the example discussed above: $\\mathcal Y = \\{y_1,y_2\\}$, with $N=2$ and $N^r = [4,3]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "76f4476c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7-element Vector{Vector{Int64}}:\n",
       " [1, 1, 1, 1]\n",
       " [1, 1, 1, 2]\n",
       " [1, 1, 2]\n",
       " [1, 2]\n",
       " [2, 1]\n",
       " [2, 2, 1]\n",
       " [2, 2, 2]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "N = 2\n",
    "refiNs = [4,3]\n",
    "generateHistories(N,refiNs)"
   ]
  },
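  {
   "cell_type": "markdown",
   "id": "sketch-partition-check",
   "metadata": {},
   "source": [
    "The partition property mentioned in the presentation can be checked mechanically. Since each history stores the current state first, every history of maximal length should start with exactly one element of the refined set. The sketch below (hypothetical helper `is_partition`, not part of the notebook) enumerates all histories of maximal length and counts matches, using the output of the example above:\n",
    "\n",
    "```julia\n",
    "# Hypothetical check (not part of the notebook): hs is a partition if every\n",
    "# history of maximal length M starts with exactly one element of hs\n",
    "# (histories store the current state first).\n",
    "function is_partition(hs::Vector{Vector{Int}}, ny::Int)::Bool\n",
    "    M = maximum(length, hs)\n",
    "    for t in Iterators.product(ntuple(_ -> 1:ny, M)...)\n",
    "        long = collect(t)\n",
    "        nmatch = count(h -> long[1:length(h)] == h, hs)\n",
    "        nmatch == 1 || return false\n",
    "    end\n",
    "    return true\n",
    "end\n",
    "\n",
    "hs_example = [[1,1,1,1], [1,1,1,2], [1,1,2], [1,2], [2,1], [2,2,1], [2,2,2]]\n",
    "is_partition(hs_example, 2)             # true\n",
    "is_partition(hs_example[1:end-1], 2)    # false: histories starting with [2,2,2] are not covered\n",
    "```"
   ]
  },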
  {
   "cell_type": "markdown",
   "id": "7ca07f3a",
   "metadata": {},
   "source": [
    "#### Verification \n",
    "\n",
    "We check that the new implementation returns exactly the same histories, in the same order, as an older, deprecated implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ea855ca1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "areEqual (generic function with 1 method)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# this is an old deprecated function\n",
    "function refine_one_hist_old(x::Vector{I},ny::I)::Vector{Vector{I}} where I<:Int\n",
    "    if all(x.==x[1])\n",
    "        return map(n->push!(copy(x),n),1:ny)\n",
    "    else\n",
    "        return [copy(x)]\n",
    "    end\n",
    "end\n",
    "\n",
    "function refine_hist_old(Nr::I,hs::Vector{Vector{I}},ny::I)::Vector{Vector{I}} where I<:Int\n",
    "    if Nr<=0\n",
    "        return hs\n",
    "    else\n",
    "        refine_hist_old(Nr-1,mapreduce(x -> refine_one_hist_old(x,ny),vcat,hs),ny)\n",
    "    end\n",
    "end\n",
    "\n",
    "function refinedbuild_histories_old(N::I,Vidio::Vector{I},ny::I)::Vector{Vector{I}} where I<:Int\n",
    "    \n",
    "    hs  = generateHistories(N,ny)\n",
    "    # first one\n",
    "    Vid = (repeat([1],N)) # refine the [1,...,1] history\n",
    "    res_i=convertBasisE10(Vid,ny)\n",
    "    hfin = refine_hist_old(Vidio[1]-N,[hs[res_i]],ny)\n",
    "    res_n = res_i  # I store the index of the last refined history\n",
    "    for i=2:ny # refine other histories\n",
    "        Vid = (repeat([i],N))\n",
    "        res_i=convertBasisE10(Vid,ny)\n",
    "        hfin =  vcat(hfin,hs[res_n+1:res_i-1])\n",
    "        hr = refine_hist_old(Vidio[i]-N,[hs[res_i]],ny)\n",
    "        res_n = res_i\n",
    "        hfin =  vcat(hfin,hr)\n",
    "    end\n",
    "    return hfin\n",
    "end\n",
    "\n",
    "function areEqual(xs::Vector{Vector{T}},ys::Vector{Vector{T}}) where T<:Number\n",
    "    length(xs) !== length(ys) && return false\n",
    "    any(map(length,xs) .!== map(length,ys)) && return false\n",
    "    return all(map(x -> all(x .== zero(eltype(x))), xs.-ys))\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0f2f7aef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "391-element Vector{Vector{Int64}}:\n",
       " [1, 1, 1, 1]\n",
       " [1, 1, 1, 2]\n",
       " [1, 1, 1, 3]\n",
       " [1, 1, 1, 4]\n",
       " [1, 1, 1, 5]\n",
       " [1, 1, 1, 6]\n",
       " [1, 1, 1, 7]\n",
       " [1, 1, 2]\n",
       " [1, 1, 3]\n",
       " [1, 1, 4]\n",
       " [1, 1, 5]\n",
       " [1, 1, 6]\n",
       " [1, 1, 7]\n",
       " ⋮\n",
       " [7, 7, 7, 7, 7, 7, 2]\n",
       " [7, 7, 7, 7, 7, 7, 3]\n",
       " [7, 7, 7, 7, 7, 7, 4]\n",
       " [7, 7, 7, 7, 7, 7, 5]\n",
       " [7, 7, 7, 7, 7, 7, 6]\n",
       " [7, 7, 7, 7, 7, 7, 7, 1]\n",
       " [7, 7, 7, 7, 7, 7, 7, 2]\n",
       " [7, 7, 7, 7, 7, 7, 7, 3]\n",
       " [7, 7, 7, 7, 7, 7, 7, 4]\n",
       " [7, 7, 7, 7, 7, 7, 7, 5]\n",
       " [7, 7, 7, 7, 7, 7, 7, 6]\n",
       " [7, 7, 7, 7, 7, 7, 7, 7]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "using NBInclude\n",
    "using SparseArrays\n",
    "using Parameters\n",
    "@nbinclude(\"Structures.ipynb\");\n",
    "@nbinclude(\"Utils.ipynb\");\n",
    "refiNs = [4,3,3,5,3,2,8]\n",
    "N = 3\n",
    "\n",
    "xs = generateHistories(N,refiNs) \n",
    "ys = refinedbuild_histories_old(N,refiNs,length(refiNs))\n",
    "if areEqual(xs,ys)\n",
    "    xs\n",
    "else\n",
    "    @warn(\"xs and ys not equal\")\n",
    "end"
   ]
  },
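  {
   "cell_type": "markdown",
   "id": "sketch-count-refined",
   "metadata": {},
   "source": [
    "The lengths of the outputs above (7 and 391 histories) can be cross-checked with a closed-form count. Each refinement step replaces one constant history by $n_y$ histories, a net gain of $n_y-1$. The constant history of state $k$ is refined $\\max(N^r_k-N,0)$ times, so the number of refined truncated histories is $n_y^N+(n_y-1)\\sum_{k}\\max(N^r_k-N,0)$. A one-line sketch (the name `count_refined` is hypothetical):\n",
    "\n",
    "```julia\n",
    "# Hypothetical helper: closed-form number of refined truncated histories,\n",
    "# ny^N uniform histories plus (ny-1) new histories per refinement step.\n",
    "count_refined(N::Int, refiNs::Vector{Int}) =\n",
    "    length(refiNs)^N + (length(refiNs) - 1) * sum(max.(refiNs .- N, 0))\n",
    "\n",
    "count_refined(2, [4, 3])                 # 7\n",
    "count_refined(3, [4, 3, 3, 5, 3, 2, 8])  # 391\n",
    "```"
   ]
  },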
  {
   "cell_type": "markdown",
   "id": "027e4f70",
   "metadata": {},
   "source": [
    "# Characterizing the distribution of truncated histories <a id=\"transition-stat-dist\"></a>[<font size=1>(back to menu)</font>](#Truncation) \n",
    "\n",
    "We now characterize the distribution of truncated histories. More precisely, we compute:\n",
    "* the [sizes](#hist-sizes) of truncated histories: function `historySizes`;\n",
    "* the [transition matrix](#transition) between truncated histories: `historyTrans`;\n",
    "* the [distribution of truncated histories over the asset × productivity grid](#stat-dist-truncated): `historyDist`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4ed2734",
   "metadata": {},
   "source": [
    "## Size of truncated histories <a id=\"hist-sizes\"></a>[<font size=1>(back to transition matrix/stat dist)</font>](#transition-stat-dist)\n",
    "\n",
    "[Presentation](#pres-hist-sizes) and [implementation](#imp-hist-sizes).\n",
    "\n",
    "### Presentation <a id=\"pres-hist-sizes\"></a>[<font size=1>(back to transition matrix/stat dist)</font>](#hist-sizes)\n",
    "\n",
    "We are interested in computing the size of truncated histories, $(S_h)_{h\\in\\mathcal H}$. Consider a truncated history $h=(y^h_{-n_h+1},\\ldots,y^h_0)\\in\\mathcal H$ of length $n_h$. The size of $h$ is the share of agents who experienced history $h$ over the last $n_h$ periods. We compute $S_h$ as the share of agents with productivity $y^h_{-n_h+1}$ multiplied by the probability of experiencing the successive transitions from $y^h_{-n_h+1}$ to $y^h_0$ implied by $h$. Formally:\n",
    "$$ S_h = S^y_{y^h_{-n_h+1}} \\cdot \\prod_{j=1}^{n_h-1} \\Pi^y_{y^h_{-j}y^h_{-j+1}},$$\n",
    "where $S^y_{y^h_{-n_h+1}}$ is the share of agents with productivity level $y^h_{-n_h+1}$ and $\\Pi^y_{y^h_{-j}y^h_{-j+1}}$ is the probability to move from productivity level $y^h_{-j}$ to $y^h_{-j+1}$."
   ]
  },
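  {
   "cell_type": "markdown",
   "id": "sketch-stationary-dist",
   "metadata": {},
   "source": [
    "The size formula requires the stationary distribution $S^y$ of the productivity process, i.e., the solution of $S^y = (\\Pi^y)^\\top S^y$ with entries summing to one. In `historySizes` below, $S^y$ is passed as the input `Sy`. As an illustration only (not necessarily how `Sy` is computed elsewhere in the project), it can be obtained by plain power iteration, assuming an irreducible and aperiodic chain. The function name and the example matrix are hypothetical:\n",
    "\n",
    "```julia\n",
    "# Hypothetical sketch: stationary distribution of a row-stochastic matrix Piy\n",
    "# (Piy[i, j] = probability to move from state i to state j) by power iteration.\n",
    "function stationary_dist(Piy::Matrix{Float64}; maxiter::Int=100_000, tol::Float64=1e-14)::Vector{Float64}\n",
    "    ny = size(Piy, 1)\n",
    "    S = fill(1.0 / ny, ny)          # start from the uniform distribution\n",
    "    for _ in 1:maxiter\n",
    "        Snew = Piy' * S             # one period forward: Snew[j] = sum_i S[i] * Piy[i, j]\n",
    "        if maximum(abs.(Snew .- S)) < tol\n",
    "            return Snew\n",
    "        end\n",
    "        S = Snew\n",
    "    end\n",
    "    return S                        # not converged within maxiter: last iterate\n",
    "end\n",
    "\n",
    "stationary_dist([0.9 0.1; 0.2 0.8])   # approximately [2/3, 1/3]\n",
    "```"
   ]
  },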
  {
   "cell_type": "markdown",
   "id": "6dc60b47",
   "metadata": {},
   "source": [
    "### Implementation <a id=\"imp-hist-sizes\"></a>[<font size=1>(back to transition matrix/stat dist)</font>](#hist-sizes)\n",
    "\n",
    "The function: \n",
    "> `historySizes(hs::Vector{Vector{I}}, Πy::Matrix{T}, Sy::Vector{T}; maxiter::I=1000000, tol::T=1e-16) where{I<:Integer, T<:Real}` \n",
    "\n",
    "takes as inputs:\n",
    "\n",
    "* a set of truncated histories `hs` (for each history `h`, `h[1]` is the current state and `h[end]` the oldest one);\n",
    "* a transition matrix `Πy` between idiosyncratic states (`Πy[h[2],h[1]]` is the transition probability from state `h[2]` to state `h[1]`; be careful with the order within `h`);\n",
    "* the stationary distribution `Sy` over idiosyncratic states;\n",
    "* keyword arguments `maxiter` and `tol`, which are accepted but not used in the current implementation (`Sy` is passed as an input rather than computed inside the function).\n",
    "\n",
    "\n",
    "The function computes the size of each truncated history. More precisely, it returns a tuple `(ind_h, S_h, y0h[ind_h])` where:\n",
    "* `ind_h` contains the indices of histories with positive size: the elements `hs[ind_h]` have strictly positive size and all other elements of `hs` have zero size;\n",
    "* `S_h` contains the sizes of these histories (same length as `ind_h`);\n",
    "* `y0h[ind_h]` is the vector of current productivity indices of these histories.\n",
    "\n",
    "***Remarks***\n",
    "* This function allows one to work only with truncated histories of non-zero size. In practice, this reduces the number of histories to consider.\n",
    "* The length of `ind_h` is not known ex ante."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6a099618",
   "metadata": {},
   "outputs": [],
   "source": [
    "function historySizes(hs::Vector{Vector{I}}, Πy::Matrix{T}, Sy::Vector{T};\n",
    "    maxiter::I=1000000, tol::T=1e-16) where{I<:Integer, T<:Real}\n",
    "\n",
    "    # nb of productivity states\n",
    "    ny = size(Πy,1)       \n",
    "    \n",
    "    # initialization\n",
    "    Ntot = length(hs)   # total nb of truncated histories\n",
    "    Sh   = zeros(T, Ntot) \n",
    "                        # initialization of size of truncated histories\n",
    "    y0h   = zeros(I, Ntot) \n",
    "                        # initialization of current productivity indices of history\n",
    "\n",
    "    # We loop over all truncated histories\n",
    "    nh = zero(I)\n",
    "    for (ih,h) in enumerate(hs)\n",
    "        nh = length(h)\n",
    "        Sh[ih] = Sy[h[nh]] \n",
    "                    # distribution according to the \n",
    "                    # terminal productivity level\n",
    "        y0h[ih] = h[1]\n",
    "                    # current prod. level of h\n",
    "        for j = nh-1:-1:1\n",
    "            Sh[ih] *= Πy[h[j+1],h[j]] \n",
    "            # We move backwards and compute the probability to move \n",
    "            # from productivity h[j+1] to productivity h[j]\n",
    "        end  \n",
    "    end\n",
    "    \n",
    "    # we extract histories with non-zero size\n",
    "    ind_h, S_h = findnz(sparsevec(Sh))\n",
    "    return ind_h, S_h, y0h[ind_h]\n",
    "end;"
   ]
  },
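  {
   "cell_type": "markdown",
   "id": "sketch-history-sizes",
   "metadata": {},
   "source": [
    "To illustrate both the size formula and the removal of zero-size histories, the sketch below re-implements the loop of `historySizes` for a hypothetical three-state chain in which direct transitions between states 1 and 3 are impossible. The histories `[1,3]` and `[3,1]` then have zero size and are dropped by the same `sparsevec`/`findnz` step (from the `SparseArrays` standard library). The helper `history_size` is hypothetical:\n",
    "\n",
    "```julia\n",
    "using SparseArrays\n",
    "\n",
    "# Hypothetical 3-state chain: no direct transition between states 1 and 3\n",
    "Piy = [0.5 0.5 0.0; 0.25 0.5 0.25; 0.0 0.5 0.5]\n",
    "Sy  = [0.25, 0.5, 0.25]             # its stationary distribution\n",
    "\n",
    "# Size of one history h (current state first, oldest state last)\n",
    "function history_size(h::Vector{Int}, Piy::Matrix{Float64}, Sy::Vector{Float64})::Float64\n",
    "    S = Sy[h[end]]                  # share of agents in the oldest state\n",
    "    for j in length(h)-1:-1:1\n",
    "        S *= Piy[h[j+1], h[j]]      # transition from h[j+1] (older) to h[j] (newer)\n",
    "    end\n",
    "    return S\n",
    "end\n",
    "\n",
    "hs    = [[i, j] for i in 1:3 for j in 1:3]   # uniform histories, N = 2\n",
    "sizes = [history_size(h, Piy, Sy) for h in hs]\n",
    "ind_h, S_h = findnz(sparsevec(sizes))        # keep positive-size histories only\n",
    "ind_h      # [1, 2, 4, 5, 6, 8, 9]: hs[3] = [1,3] and hs[7] = [3,1] are dropped\n",
    "sum(S_h)   # 1.0\n",
    "```"
   ]
  },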
  {
   "cell_type": "markdown",
   "id": "e00a6323",
   "metadata": {},
   "source": [
    "## Transition matrix  <a id=\"transition\"></a>[<font size=1>(back to transition matrix/stat dist)</font>](#transition-stat-dist)\n",
    "\n",
    "The function: \n",
    "\n",
    "> `historyTrans(hs::Vector{Vector{I}}, Πy::Matrix{T})::SparseMatrixCSC{T,I} where{I<:Integer, T<:Real}` \n",
    "\n",
    "takes as inputs (same as before):\n",
    "* a set of truncated histories `hs`;\n",
    "* a transition matrix `Πy` between idiosyncratic states.\n",
    "\n",
    "The function returns a transition matrix `Πh` between the elements of `hs`: for instance, `Πh[1,2]` is the transition probability from `hs[1]` to `hs[2]`. The matrix `Πh` is square, with as many rows (and columns) as there are elements in `hs`, and is stored as a sparse matrix of type `SparseMatrixCSC{T,I}`.\n",
    "\n",
    "***Remarks.***\n",
    "* We check that the matrix is a proper transition matrix (entries in $[0,1]$ and rows summing to 1) using the function `isTransMat`.\n",
    "* The input `hs` must be such that the output is a well-defined transition matrix (removing zero-size truncated histories is safe)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cd4f058b",
   "metadata": {},
   "outputs": [],
   "source": [
    "function isTransMat(M::Matrix{T})::Bool where{T<:Real}\n",
    "    all(M.≥zero(T)) || return false             # if one element is negative\n",
    "    all(M.≤one(T))  || return false              # if one element is greater than 1\n",
    "    all(sum(M,dims=2) .≈ one(T))|| return false # if one row does not sum to 1\n",
    "    return true\n",
    "end\n",
    "\n",
    "function isTransMat(M::SparseMatrixCSC)\n",
    "    return isTransMat(Matrix(M))\n",
    "end\n",
    "\n",
    "function historyTrans(hs::Vector{Vector{I}},Πy::Matrix{T})::SparseMatrixCSC{T,I} where{I<:Integer,T<:Real}\n",
    "    # We construct a sparse square matrix Πh of length equal to the one of hs \n",
    "    Ntot    = length(hs)\n",
    "    Πh      = spzeros(T,Ntot,Ntot)\n",
    "    nh, nht = zero(I), zero(I)\n",
    "    \n",
    "    # we loop over histories \n",
    "    # we will compute the transition proba from history h\n",
    "    for (i,h) ∈ enumerate(hs)\n",
    "        nh = length(h)\n",
    "        # we loop over destination histories\n",
    "        # we compute transition proba from h to ht\n",
    "        for (j,ht) ∈ enumerate(hs)\n",
    "            nht = length(ht)\n",
    "            # we check whether ht is a continuation of h\n",
    "            \n",
    "            # case 1: h is longer than ht\n",
    "            if (nh≥nht) && all(ht[2:end] .== h[1:nht-1])\n",
    "                Πh[i,j] = Πy[h[1],ht[1]]\n",
    "            # case 2: ht is longer than h\n",
    "            elseif (nht>nh) && all(ht[2:(nh+1)] .== h)\n",
    "                Πh[i,j] = Πy[h[1],ht[1]]\n",
    "            end\n",
    "        end\n",
    "    end\n",
    "    #print_array(Matrix(Πh))\n",
    "    #@show sum(Πh), sum(Πh,dims=1), sum(Πh,dims=2)\n",
    "    @assert isTransMat(Πh) # checking that the matrix is actually a transition matrix.\n",
    "    return Πh\n",
    "end;"
   ]
  },
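  {
   "cell_type": "markdown",
   "id": "sketch-history-trans",
   "metadata": {},
   "source": [
    "The two cases in `historyTrans` can be summarized by a single rule: `ht` can follow `h` when, once its current state `ht[1]` is dropped, `ht` agrees with `h` on all periods recorded by both histories. The dense sketch below (hypothetical helper names and a hypothetical two-state matrix) applies this rule to the refined histories of the example $N=2$, $N^r=[4,3]$ and checks that every row of the resulting matrix sums to one:\n",
    "\n",
    "```julia\n",
    "# Hypothetical sketch of the continuation rule of historyTrans (dense version):\n",
    "# ht can follow h if ht, without its current state ht[1], agrees with h on the\n",
    "# periods recorded by both histories.\n",
    "function is_continuation(h::Vector{Int}, ht::Vector{Int})::Bool\n",
    "    m = min(length(h), length(ht) - 1)   # number of periods recorded by both\n",
    "    return ht[2:m+1] == h[1:m]\n",
    "end\n",
    "\n",
    "function history_trans_dense(hs::Vector{Vector{Int}}, Piy::Matrix{Float64})::Matrix{Float64}\n",
    "    n = length(hs)\n",
    "    P = zeros(n, n)\n",
    "    for (i, h) in enumerate(hs), (j, ht) in enumerate(hs)\n",
    "        if is_continuation(h, ht)\n",
    "            P[i, j] = Piy[h[1], ht[1]]   # move from current state h[1] to ht[1]\n",
    "        end\n",
    "    end\n",
    "    return P\n",
    "end\n",
    "\n",
    "Piy = [0.9 0.1; 0.2 0.8]                 # hypothetical productivity transitions\n",
    "hs_example = [[1,1,1,1], [1,1,1,2], [1,1,2], [1,2], [2,1], [2,2,1], [2,2,2]]\n",
    "P = history_trans_dense(hs_example, Piy)\n",
    "all(isapprox.(sum(P, dims=2), 1.0))      # true: each row is a probability vector\n",
    "```"
   ]
  },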
  {
   "cell_type": "markdown",
   "id": "6c97ea7a",
   "metadata": {},
   "source": [
    "## Distribution of truncated histories over the asset × productivity grid  <a id=\"stat-dist-truncated\"></a>[<font size=1>(back to transition matrix/stat dist)</font>](#transition-stat-dist)\n",
    "\n",
    "\n",
    "[Presentation](#pres-hist-product-grid) and [implementation](#imp-hist-product-grid).\n",
    "\n",
    "### Presentation <a id=\"pres-hist-product-grid\"></a>[<font size=1>(back to stationary distribution product grid)</font>](#stat-dist-truncated)\n",
    "\n",
    "We compute the distribution of truncated histories over the product grid asset $\\times$ productivity. The resulting matrix has dimension $(n_a\\cdot n_y,N_{tot})$, where $N_{tot}$ is the number of truncated histories. The idea is similar to the computation of history sizes in `historySizes`, except that asset holdings are now tracked along each history.\n",
    "\n",
    "As before, consider a truncated history $h=(y^h_{-n_h+1},\\ldots,y^h_0)\\in\\mathcal H$ of length $n_h$. We compute $\\tilde\\Lambda_{a,y,h}$, the mass of agents with history $h$, asset holding $a$, and current productivity level $y$. In practice, $a$ only takes values in the grid $a_\\text{Grid}$, so that the distribution is actually $(\\tilde\\Lambda_{a_{\\text{Grid},i_a},y,h})_{i_a=1,\\ldots,n_a,y\\in\\mathcal Y, h\\in\\mathcal H}$. Denote by $\\Lambda_{a,y}$ the stationary distribution of the Aiyagari model over the product grid and by $T_{(a,y),(a',y')}$ the probability to move from $(a,y)$ to $(a',y')$ in one period. We start from the asset distribution of agents with productivity $y^h_{-n_h+1}$ (the oldest state of $h$) and propagate it, period by period, along the productivity transitions implied by $h$. Formally:\n",
    "\\begin{align*}\n",
    "    \\tilde\\Lambda^{(n_h-1)}_{a} &= \\Lambda_{a,y^h_{-n_h+1}},\\\\\n",
    "    \\tilde\\Lambda^{(j-1)}_{a'} &= \\sum_{a\\in a_\\text{Grid}} \\tilde\\Lambda^{(j)}_{a}\\, T_{(a,y^h_{-j}),(a',y^h_{-j+1})}, \\quad j=n_h-1,\\ldots,1,\\\\\n",
    "    \\tilde\\Lambda_{a,y,h} &= \\begin{cases}\\tilde\\Lambda^{(0)}_{a}, \\text{ if }y=y_0^h,\\\\\n",
    "    0, \\text{ otherwise}.\\end{cases}\n",
    "\\end{align*}\n",
    "\n",
    "When the asset transition does not depend on next-period productivity, so that $\\sum_{a'}T_{(a,y),(a',y')}=\\Pi^y_{yy'}$, summing $\\tilde\\Lambda_{a,y,h}$ over $a$ and $y$ recovers the size $S_h$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c63fa003",
   "metadata": {},
   "source": [
    "### Implementation <a id=\"imp-hist-product-grid\"></a>[<font size=1>(back to stationary distribution product grid)</font>](#stat-dist-truncated)\n",
    "\n",
    "\n",
    "The function: \n",
    "\n",
    "> `historyDist(hs::Vector{Vector{I}},stationaryDist::Matrix{T},transitMat::SparseMatrixCSC{T,I},ny::I)::Matrix{T} where{I<:Integer, T<:Real}` \n",
    "    \n",
    "takes as inputs:\n",
    "    \n",
    "* a set of truncated histories `hs`;\n",
    "* a stationary distribution `stationaryDist` over the product grid productivity × asset, stored as an $n_a\\times n_y$ matrix whose column `k` is the distribution over the asset grid of agents with productivity index `k` (as computed when solving the Aiyagari model in [SolveAiyagari.ipynb](./SolveAiyagari.ipynb));\n",
    "* a transition matrix `transitMat` between elements of the product grid productivity × asset, where `transitMat[i,j]` is the probability to move from grid point `i` to grid point `j` and grid points are ordered by productivity blocks of length $n_a$ (as computed when solving the Aiyagari model in [SolveAiyagari.ipynb](./SolveAiyagari.ipynb));\n",
    "* the number of productivity levels `ny`.  \n",
    "\n",
    "The function returns a $((n_a \\times n_y), N_{tot})$-matrix (`Matrix` type) containing the distribution of truncated histories over the product grid asset $\\times$ productivity. More precisely, if `statDist_h = historyDist(hs, stationaryDist, transitMat, ny)`, then `statDist_h[:,ih]` is the distribution of the truncated history `hs[ih]` over the product grid asset $\\times$ productivity.\n",
    "\n",
    "The implementation mirrors the one of `historySizes`, except that computations are done on the product grid: the scalar transition probabilities are replaced by the blocks of `transitMat` linking productivity `h[j+1]` to productivity `h[j]`. This requires some care with block indices and with the orientation of matrix products (the function works with the transpose of `transitMat`).\n",
    "\n",
    "***Remark:*** the stationary distribution and the transition matrix come from the solution of the Aiyagari model (see [Aiyagari notebook](./SolveAiyagari.ipynb))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59142d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "function historyDist(hs::Vector{Vector{I}},stationaryDist::Matrix{T}, \n",
    "    transitMat::SparseMatrixCSC{T,I},ny::I)::Matrix{T} where{I<:Integer, T<:Real}\n",
    "    transitMat = Matrix(transitMat')  # we express the transition matrix column-wise\n",
    "    na   = div(size(transitMat,1),ny) # size of asset grid \n",
    "    Ntot = length(hs)                 # nb of truncated histories\n",
    "    \n",
    "    # initialization of the stat distribution: matrix of size (na*ny, Ntot)\n",
    "    statDist_h = zeros(T, na*ny, Ntot)\n",
    "    temp_v = zeros(T, na)\n",
    "    nh = zero(I)\n",
    "    \n",
    "    # Loop over all truncated histories\n",
    "    for (ih,h) in enumerate(hs)\n",
    "        shift0 = (h[end]-1)*na # shift for the index of h[end] in the product grid\n",
    "        # we initialize with the stationary distribution \n",
    "        # corresponding to the terminal productivity state h[end]. \n",
    "        statDist_h[1+shift0:na+shift0,ih] =  stationaryDist[:,h[end]]\n",
    "        nh = length(h)\n",
    "        for j=nh-1:-1:1\n",
    "            # shifts for accounting for the asset grid\n",
    "            shift1 = (h[j+1]-1)*na\n",
    "            shift0 = (h[j]-1)*na\n",
    "            \n",
    "            # we update the stationary distribution using the transition matrix\n",
    "            temp_v .= statDist_h[1+shift1:na+shift1,ih]\n",
    "            statDist_h[:,ih] .= zero(T)\n",
    "            statDist_h[1+shift0:na+shift0,ih] .= (\n",
    "                transitMat[1+shift0:na+shift0,1+shift1:na+shift1] * temp_v)\n",
    "            # we update the probability at step j+1 by multiplying it by the transition from \n",
    "            # state h[j+1] to h[j] (transitMat is column-wise)\n",
    "        end\n",
    "    end\n",
    "    return statDist_h\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "401c3467",
   "metadata": {},
   "source": [
    "# The truncated model <a id=\"truncated\"></a>[<font size=1>(back to menu)</font>](#Truncation) \n",
    "\n",
    "We now proceed with the construction of the truncated model. The construction is done in three steps: \n",
    "1. The characterization of [credit-constrained histories](#cc-truncated);\n",
    "2. The [truncation of a specific policy function](#pol-truncated);\n",
    "3. The construction of the [truncated model](#mod-truncated) itself."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67137a4f",
   "metadata": {},
   "source": [
    "## Credit-constrained histories <a id=\"cc-truncated\"></a>[<font size=1>(back to truncated model)</font>](#truncated) \n",
    "\n",
    "We explain how to construct the set $\\mathcal C$ of credit-constrained truncated histories: [presentation](#pres-cc) and [implementation](#imp-cc).\n",
    "\n",
    "### Presentation <a id=\"pres-cc\"></a>[<font size=1>(back to credit-constrained histories)</font>](#cc-truncated)\n",
    "\n",
    "The objective is to determine the set $\\mathcal C$ of credit-constrained truncated histories. The selection criterion is that the total size of the credit-constrained truncated histories should be as close as possible to the share of credit-constrained agents in the Aiyagari model. \n",
    "\n",
    "The algorithm is as follows:\n",
    "1. Sort the truncated histories according to a given criterion (typically the Lagrange multipliers on the credit constraints) and denote by $(h_{k_i})_{i=1,\\ldots,N_{tot}}$ the sorted set.\n",
    "2. The number $i_{cc}$ of credit-constrained histories is given by:\n",
    "$$i_{cc} = \\arg\\min_{j\\in\\{1,\\ldots,N_{tot}\\}} d\\Bigl(\\sum_{i=1}^j S_{h_{k_i}},\\Lambda_{cc}\\Bigr),$$\n",
    "where: \n",
    "    * $(S_h)_h$ are the sizes of the truncated histories;\n",
    "    * $d$ is a map that formalizes the notion of closeness we care about. We consider two cases: (i) $d(x,y) = |x-y|$, the Euclidean distance; and (ii) $d(x,y) = x-y$ if $x\\ge y$ and $d(x,y) = +\\infty$ otherwise, which selects the smallest cumulative size that is at least $\\Lambda_{cc}$ (this is only a pseudo-distance since it is not symmetric);\n",
    "    * $\\Lambda_{cc}$ is the share of credit-constrained agents in the Aiyagari model.\n",
    "3. The set of credit-constrained histories is then given by:\n",
    "$$\\mathcal C = \\{h_{k_i}:i=1,\\ldots,i_{cc}\\}.$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7d99606",
   "metadata": {},
   "source": [
    "### Implementation <a id=\"imp-cc\"></a>[<font size=1>(back to credit-constrained histories)</font>](#cc-truncated)\n",
    "\n",
    "The function \n",
    "> `credit_constrained_h(shareCC::T, S_h::Vector{T}, x_h::Vector{T}; method::String=\"closest\", rev::Bool=true)` \n",
    "\n",
    "takes as input:\n",
    "* `shareCC`: the target share $\\Lambda_{cc}$ of credit-constrained agents;\n",
    "* `S_h`: the vector of sizes $(S_h)_h$ of truncated histories;\n",
    "* `x_h`: a vector of the same size as `S_h` along which the sorting of truncated histories is done;\n",
    "* `method`: a string characterizing the selection mechanism, set by default to `\"closest\"`; it plays the role of the map $d$ above: `\"closest\"` corresponds to case (i), and any other value (e.g., `\"first_larger\"`) to case (ii);\n",
    "* `rev`: a Boolean (by default `true`) specifying the sorting order (`true` implies decreasing order).\n",
    "\n",
    "The function returns a pair:\n",
    "* an integer `i_cc` indicating the number of credit-constrained truncated histories;\n",
    "* the vector of indices of credit-constrained truncated histories.\n",
    "\n",
    "In words, the function returns the first `i_cc` truncated histories (in the sorted order) whose cumulative size is the most *similar* to the target `shareCC`. The meaning of *similar* is specified by `method`: either the first cumulative size larger than the target or the closest one (see the formal definitions of $d$ above). By default, truncated histories are sorted in decreasing order of `x_h` (typically the Lagrange multipliers on the credit constraint). If no cumulative size exceeds the target under the first-larger rule, a single history is returned.\n",
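    "\n",
    "The two selection rules can be illustrated on hypothetical toy values (five histories with made-up sizes and multipliers). This sketch mirrors the logic of `credit_constrained_h` without calling it, so that it runs on its own; with a target of 0.11, the closest rule keeps one history while the first-larger rule keeps two.\n",
    "\n",
    "```julia\n",
    "S_h     = [0.10, 0.25, 0.05, 0.30, 0.30]  # sizes of 5 truncated histories (sum to 1)\n",
    "x_h     = [0.8, 0.1, 0.5, 0.0, 0.3]       # sorting criterion, e.g. Lagrange multipliers\n",
    "shareCC = 0.11                            # target share Λ_cc\n",
    "\n",
    "order = sortperm(x_h, rev=true)           # decreasing criterion: [1, 3, 5, 2, 4]\n",
    "cumS  = cumsum(S_h[order])                # ≈ [0.10, 0.15, 0.45, 0.70, 1.00]\n",
    "\n",
    "# closest rule: minimize |cumS - shareCC|\n",
    "i_closest = argmin(abs.(cumS .- shareCC))                 # 1\n",
    "# first-larger rule: first cumulative size strictly above the target (at least 1)\n",
    "i_larger  = something(findfirst(>(shareCC), cumS), 1)     # 2\n",
    "\n",
    "order[1:i_closest], order[1:i_larger]     # ([1], [1, 3])\n",
    "```\n",
    "\n",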
    "        \n",
    "\n",
    "\n",
    "***Remarks.***\n",
    "\n",
    "1. The selection of credit-constrained histories is key for the accuracy of the simulations. \n",
    "\n",
    "2. It is possible to rank histories along another quantity than Lagrange multipliers (e.g., end-of-period savings). The sorting order may then need to be reversed: for end-of-period savings, `rev` should be set to `false`, so that the histories with the lowest savings come first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "948461f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "function credit_constrained_h(shareCC::T, S_h::Vector{T}, x_h::Vector{T}; \n",
    "                method::String=\"closest\",rev::Bool=true) where{T<:Real}\n",
    "    # We sort S_h along x_h following the order rev\n",
    "    ind_x_sort = sortperm(x_h,rev=rev)\n",
    "    S_h_sorted = S_h[ind_x_sort]\n",
    "    \n",
    "    if method==\"closest\"\n",
    "        i_cc = argmin(abs.(cumsum(S_h_sorted) .- shareCC)) #closest distance\n",
    "    else \n",
    "        i_cc = findfirst(x->x>zero(T), cumsum(S_h_sorted) .- shareCC)#at least as large\n",
    "    end\n",
    "    i_c = isnothing(i_cc) ? 1 : i_cc # we return at least one credit-constrained history\n",
    "\n",
    "    return i_c, ind_x_sort[1:i_c]\n",
    "end;"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2548e22b",
   "metadata": {},
   "source": [
    "## Truncating a policy function <a id=\"pol-truncated\"></a>[<font size=1>(back to truncated model)</font>](#truncated) \n",
    "\n",
    "We explain how to truncate a specific policy function: [presentation](#pres-pol) and [implementation](#imp-pol).\n",
    "\n",
    "### Presentation <a id=\"pres-pol\"></a>[<font size=1>(back to truncating a policy function)</font>](#pol-truncated)\n",
    "\n",
    "We consider a policy function $x$ defined on the product grid $a_\\text{Grid}\\times \\mathcal Y$. We denote by $\\tilde \\Lambda:(a,y,h)\\mapsto \\tilde \\Lambda(a,y,h)$ the distribution of truncated histories over the product grid. We define the truncated policy function as follows. For all $h\\in \\mathcal H$:\n",
    "$$x_h = \\frac{\\int_{a}\\sum_y x(a,y)\\tilde \\Lambda(da,y,h)}{\\int_{a}\\sum_y \\tilde \\Lambda(da,y,h)}.$$ In words, each truncated history is assigned the average value of $x$ across the agents sharing that truncated history.\n",
    "\n",
    "In the case of a discrete distribution $(\\tilde\\Lambda(i_a,i_y,h))_{i_a = 1,\\ldots,n_a,i_y = 1,\\ldots,n_{\\mathcal Y},h \\in \\mathcal H}$ and a policy function defined on the same product grid, $(x(i_a,i_y))_{i_a,i_y}$, $x_h$ can be computed as follows for all $h$:\n",
    "$$x_h = \\frac{\\sum_{i_a,i_y}\\tilde\\Lambda(i_a,i_y,h)\\,x(i_a,i_y)}{\\sum_{i_a,i_y} \\tilde\\Lambda(i_a,i_y,h)},$$\n",
    "which is the expression that we actually implement."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aaceb81d",
   "metadata": {},
   "source": [
    "### Implementation <a id=\"imp-pol\"></a>[<font size=1>(back to truncating a policy function)</font>](#pol-truncated)\n",
    "\n",
    "The function:  \n",
    "> `truncate_polfun(gx::Matrix{T}, statDist_h::Matrix{T}; f=identity)::Vector{T} where{T<:Real}` \n",
    "\n",
    "computes, for each truncated history, the average of the variable $x$, given the individual policy function `gx` and the distribution `statDist_h` of truncated histories over the product grid.\n",
    "\n",
    "The function takes as inputs:\n",
    "* a policy function `gx` which is a matrix of dimension $(n_a,n_y)$;\n",
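    "* a distribution `statDist_h` of truncated histories over the product grid asset $\\times$ productivity. This is a matrix of dimension $(n_a \\times n_y, N_{tot})$, where $N_{tot}$ is the number of truncated histories;\n",
    "* a function `f` (by default `identity`) used to transform the variable: the mean of `gx` is first computed within each productivity level of the truncated history, `f` is applied to these conditional means, and the results are then averaged across productivity levels. For a nonlinear `f`, the result therefore generally differs from `f` applied to $x_h$.\n",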
    "\n",
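    "For intuition, here is a minimal sketch with hypothetical toy numbers ($n_a = n_y = 2$, a single truncated history). It checks that the two-stage computation performed by `truncate_polfun` (conditional means by productivity level, then an average across levels) coincides with the direct formula for $x_h$ when `f = identity`, and that it differs from $f(x_h)$ for a nonlinear `f` such as `log`:\n",
    "\n",
    "```julia\n",
    "gx = [1.0 3.0;\n",
    "      2.0 4.0]                    # policy x(i_a, i_y): rows assets, columns productivity\n",
    "Λh = [0.1, 0.3, 0.2, 0.4]         # one column of statDist_h (asset index varies fastest)\n",
    "p  = reshape(Λh, 2, 2)            # p[i_a, i_y]\n",
    "\n",
    "# Direct formula: x_h = Σ Λ x / Σ Λ\n",
    "x_h = sum(p .* gx) / sum(p)                  # ≈ 2.9\n",
    "\n",
    "# Two-stage computation\n",
    "p_y   = vec(sum(p, dims=1))                  # mass at each productivity level\n",
    "m_y   = vec(sum(p .* gx, dims=1)) ./ p_y     # conditional mean of x at each level\n",
    "x_id  = sum(p_y .* m_y) / sum(p_y)           # f = identity: equals x_h\n",
    "x_log = sum(p_y .* log.(m_y)) / sum(p_y)     # f = log: below log(x_h) by Jensen's inequality\n",
    "```\n",
    "\n",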
    "The function returns a `Vector` of length  $N_{tot}$ equal to $(x_h)_{h\\in\\mathcal{H}}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "170828fe",
   "metadata": {},
   "outputs": [],
   "source": [
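    "function truncate_polfun(gx::Matrix{T}, statDist_h::Matrix{T}; f=identity)::Vector{T} where{T<:Real}\n",
    "    na,ny  = size(gx)\n",
    "    n,Ntot = size(statDist_h)\n",
    "    @assert n==na*ny\n",
    "    p_ijk = reshape(statDist_h,na,ny,Ntot) # p_ijk[i,j,k]: mass of history k at asset i and productivity j\n",
    "    xhs = zeros(T, Ntot)\n",
    "    for k in 1:Ntot\n",
    "        p_j = sum(p_ijk[:,:,k],dims=1)          # (1,ny): mass of history k at each productivity level\n",
    "        x_j = sum(p_ijk[:,:,k].*gx,dims=1)      # (1,ny): mass-weighted sum of x at each productivity level\n",
    "        f_x_j = zeros(T, size(x_j))\n",
    "        # levels with a zero weighted sum (in particular zero-mass levels) are skipped and keep f_x_j = 0\n",
    "        nz = x_j.!=zero(T)\n",
    "        f_x_j[nz] = f.(x_j[nz]./p_j[nz])        # f applied to the conditional mean of x\n",
    "        xhs[k] = sum(p_j.*f_x_j)/sum(p_j)       # average across productivity levels\n",
    "    end\n",
    "    return xhs\n",
    "end;"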
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b1097a8",
   "metadata": {},
   "source": [
    "## Computing the truncated model <a id=\"mod-truncated\"></a> [<font size=1>(back to truncated model)</font>](#truncated) \n",
    "\n",
    "We explain how to compute the truncated model: [presentation](#pres-mod) and [implementation](#imp-mod).\n",
    "\n",
    "### Presentation <a id=\"pres-mod\"></a>[<font size=1>(back to computing the truncated model)</font>](#mod-truncated)\n",
    "\n",
    "The construction of the truncated model is rather straightforward and mostly relies on the functions defined above. The algorithm can be summarized as follows:\n",
    "1. Generate truncated histories using [`generateHistories`](#constructing-truncated-hist);\n",
    "2. Characterize the distribution (size and transition matrix) of truncated histories and restrict to positive-size histories using [`historySizes`](#hist-sizes) and [`historyTrans`](#transition);\n",
    "3. Compute the distribution of truncated histories on the product grid asset$\\,\\times\\,$productivity using [`historyDist`](#stat-dist-truncated);\n",
    "4. Compute the set of credit-constrained histories using [`credit_constrained_h`](#cc-truncated) and compute the truncated allocation using the policy functions and [`truncate_polfun`](#pol-truncated). This yields an instance of the struct `TruncatedAllocation`;\n",
    "5. Compute the $\\xi$s and return the instance of the struct `ξs_struct`;\n",
    "6. Using the outputs of the two previous steps, return the truncated model as an instance of `TruncatedModel`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bb47b78",
   "metadata": {},
   "source": [
    "### Implementation <a id=\"imp-mod\"></a>[<font size=1>(back to computing the truncated model)</font>](#mod-truncated)\n",
    "\n",
    "**This is the *core function* of the `Truncation` notebook**.\n",
    "\n",
    "\n",
    "The function: \n",
    "> `TruncatedModel(N::Integer, refiNs::Vector{Int64}, solution::AiyagariSolution, economy::Economy; maxiter=1000000, tol=1e-16, method::String=\"first_larger\")` \n",
    "\n",
    "takes as input:\n",
    "* truncation length parameters: the common truncation length `N` and the vector `refiNs` of refined truncation lengths for each productivity state;\n",
    "* a solution of the Aiyagari model, `solution`;\n",
    "* the parameters of the economy, `economy`;\n",
    "* convergence control parameters `maxiter` and `tol` (accepted as keywords but not used in the current body of the function);\n",
    "* the selection rule `method` passed to `credit_constrained_h` (by default `\"first_larger\"`, i.e., the first cumulative size larger than the target share).\n",
    "\n",
    "The function returns a `TruncatedModel` (actually, the function can be seen as a specific constructor). See [Structures.ipynb](./Structures.ipynb) for more details about this structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "217ce1a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "function TruncatedModel(N::Integer,refiNs::Vector{Int64}, # length of the truncation\n",
    "                             solution::AiyagariSolution, economy::Economy;\n",
    "                             maxiter=1000000,tol=1e-16,method::String=\"first_larger\")\n",
    "\n",
    "    @unpack β,α,δ,u,u′,u′′,l_supply,τ,na,a_min,aGrid,ny,ys,Πy,Sy,nys = economy\n",
    "    @unpack ga,gc,gl,R,w, A,K,L,transitMat,stationaryDist,residEuler = solution\n",
    "    T = typeof(β)\n",
    "\n",
    "    # Identifying positive-size truncated histories\n",
    "    #hs′  = generateHistories(N,refiNs)\n",
    "\n",
    "    \n",
    "    # Constructing the refined histories: the same structure is reproduced for each ex ante type, with shifted productivity indices.\n",
    "    @assert length(refiNs) ≥ maximum(nys) \n",
    "    hist_ys = Array{Vector{Vector{typeof(ny)}},1}(undef, length(nys))\n",
    "    for i in eachindex(nys)\n",
    "        hist_ys[i] = generateHistories(N,refiNs,nys[i])\n",
    "        if (i > 1) \n",
    "            for j in eachindex(hist_ys[i])\n",
    "                hist_ys[i][j] .+= sum(nys[1:i-1])\n",
    "            end\n",
    "        end\n",
    "    end\n",
    "    hs′   = vcat(hist_ys...)\n",
    "    Ntots = cumsum(map(length, hist_ys))\n",
    "    \n",
    "    \n",
    "    (ind_h, S_h, ind_y0_h) = historySizes(hs′,Πy,Sy)\n",
    "    ind_hs = Array{Vector{typeof(ny)},1}(undef,length(Ntots))\n",
    "    #@show typeof(ind_hs)\n",
    "    for i in eachindex(Ntots)\n",
    "        if i==1\n",
    "            #@show typeof(ind_h[ind_h.≤Ntots[1]])\n",
    "            ind_hs[1] = ind_h[ind_h.≤Ntots[1]]\n",
    "        else\n",
    "            ind_hs[i] = ind_h[(ind_h.>Ntots[i-1]).&&(ind_h.≤Ntots[i])]\n",
    "        end\n",
    "    end\n",
    "    \n",
    "\n",
    "    # We consider only non-zero histories.\n",
    "    hs   = hs′[ind_h] # Important : I redefine the histories here to consider the relevant one\n",
    "    y0_h = [ys[i] for i in ind_y0_h]\n",
    "    \n",
    "    Π_h  = historyTrans(hs,Πy)              # transition matrix\n",
    "\n",
    "    # Distribution of truncated histories over the product grid\n",
    "    statDist_h = historyDist(hs, stationaryDist,         # distribution\n",
    "                    transitMat, ny) \n",
    "    \n",
    "    # Computing truncated allocation\n",
    "    c_h        = truncate_polfun(gc, statDist_h)         # consumption\n",
    "    a_beg_h    = truncate_polfun(repeat(aGrid,1,ny),  # beginning-of-period wealth\n",
    "                                statDist_h)\n",
    "    a_end_h    = truncate_polfun(ga, statDist_h)         # end-of-period wealth\n",
    "   \n",
    "    u_h        = truncate_polfun(u.(gc,gl), statDist_h, f=x->x)     # utility of consumption\n",
    "    u′_h       = truncate_polfun(u′.(gc,gl), statDist_h, f=x->x)    # marginal utility of consumption\n",
    "    u′′_h      = truncate_polfun(u′′.(gc,gl), statDist_h, f=x->x)   # 2nd derivative of utility of consumption\n",
    "    \n",
    "    v_h        = truncate_polfun(0.0 .*(gl), statDist_h, f=x->x)    # disutility of labor\n",
    "    v′_h       = truncate_polfun(0.0 .*(gl), statDist_h, f=x->x)   # marginal disutility of labor\n",
    "    l_h        = truncate_polfun(gl, statDist_h)         # labor supply\n",
    "    \n",
    "    ly_τ_h     = truncate_polfun((repeat(ys',na,1).*gl).^τ, # efficient labor supply raised to the power τ: (y×l)^τ\n",
    "                                    statDist_h, f=x->x) \n",
    "    \n",
    "    resid_E_h  = truncate_polfun(residEuler, statDist_h)  # Euler equation residuals, to compute credit constrained history\n",
    "\n",
    "    # Credit-constrained histories\n",
    "    ind = findall(x->x<=1e-9,ga[:])\n",
    "    dist = solution.stationaryDist[:]\n",
    "    share_cc  = sum(dist[ind])\n",
    "\n",
    "    # Share of credit-constrained agents within each ex ante type\n",
    "    indtype = cumsum(economy.nys*economy.na) # last product-grid index of each ex ante type\n",
    "    sr      = zeros(length(indtype))\n",
    "    for i in eachindex(indtype)\n",
    "        lo    = i==1 ? 1 : indtype[i-1]+1\n",
    "        sr[i] = sum(dist[ind[lo .<= ind .<= indtype[i]]])\n",
    "    end\n",
    "\n",
    "    # Number of (positive-size) truncated histories of each type\n",
    "    ny_cum = cumsum(economy.nys)                   # last productivity index of each type\n",
    "    nh     = zeros(Int,length(economy.nys))\n",
    "    share  = zeros(Float64,length(economy.nys))   # share of credit-constrained histories by type\n",
    "    for i in eachindex(nh)\n",
    "        lo    = i==1 ? 0 : ny_cum[i-1]\n",
    "        nh[i] = count(x -> lo < x <= ny_cum[i], ind_y0_h)\n",
    "    end\n",
    "\n",
    "    # Credit-constrained histories, selected type by type (histories of a type are contiguous)\n",
    "    nb_cc_h_XRv  = Int[]   # number of credit-constrained histories by type\n",
    "    ind_cc_h_XRv = Int[]   # their indices in the global numbering of histories\n",
    "    starti = 1\n",
    "    for i in eachindex(sr)\n",
    "        endi = starti + nh[i] - 1\n",
    "        nb_cc_h_XR, ind_cc_h_XR = credit_constrained_h(sr[i],\n",
    "            S_h[starti:endi], resid_E_h[starti:endi], method=method, rev=true)\n",
    "        push!(nb_cc_h_XRv, nb_cc_h_XR)\n",
    "        append!(ind_cc_h_XRv, ind_cc_h_XR .+ (starti-1))\n",
    "        starti = endi + 1\n",
    "    end\n",
    "\n",
    "    nb_cc_h  = sum(nb_cc_h_XRv)\n",
    "    ind_cc_h = ind_cc_h_XRv\n",
    "\n",
    "    # Total size of credit-constrained histories within each type\n",
    "    starti = 1\n",
    "    for i in eachindex(share)\n",
    "        endi     = starti + nh[i] - 1\n",
    "        share[i] = sum(S_h[filter(x -> starti <= x <= endi, ind_cc_h)])\n",
    "        starti   = endi + 1\n",
    "    end\n",
    "\n",
    "    \n",
    "    # Constructing the truncated allocation\n",
    "    truncatedAllocation = TruncatedAllocation(\n",
    "        S_h       = S_h,\n",
    "        Π_h       = Π_h,\n",
    "        statDist_h= statDist_h,\n",
    "        y0_h      = y0_h,\n",
    "        ind_y0_h  = ind_y0_h,\n",
    "        a_beg_h   = a_beg_h,\n",
    "        a_end_h   = a_end_h,\n",
    "        c_h       = c_h,\n",
    "        l_h       = l_h,\n",
    "        ly_τ_h    = ly_τ_h,\n",
    "        u_h       = u_h,\n",
    "        u′_h      = u′_h,\n",
    "        u′′_h     = u′′_h,\n",
    "        v_h       = v_h,\n",
    "        v′_h      = v′_h,\n",
    "        resid_E_h = resid_E_h,\n",
    "        nb_cc_h   = nb_cc_h,\n",
    "        ind_cc_h  = ind_cc_h,\n",
    "        share     = share)\n",
    "\n",
    "    # Computing the ξs    \n",
    "    ξu0 = u_h./u.(c_h,l_h)\n",
    "    ξu1 = u′_h./u′.(c_h,l_h)\n",
    "    ξu2 = u′′_h./u′′.(c_h,l_h)\n",
    "\n",
    "    ξuE = ((I-β*R*Π_h)\\resid_E_h)./u′.(c_h,l_h)\n",
    "    \n",
    "    ξy  = ly_τ_h./(y0_h.*l_h).^τ\n",
    "    ξv0 = 0.0 .*l_h # v_h./v.(l_h)\n",
    "    ξv1 = 0.0 .*l_h #(1-τ)*w*ly_τ_h.*u′_h./(l_h.*v′.(l_h)) #(equivalent to τ*w*ξy*.(y0_h.*l_h).^τ?*ξu1.*u′.(c_h)./l_h)\n",
    "\n",
    "    ξs = ξs_struct(\n",
    "        ξu0 = ξu0,\n",
    "        ξu1 = ξu1,\n",
    "        ξu2 = ξu2,\n",
    "        ξuE = ξuE,\n",
    "        ξy  = ξy,\n",
    "        ξv0 = ξv0,\n",
    "        ξv1 = ξv1)\n",
    "    \n",
    "    return TruncatedModel(\n",
    "            N                = N,      # common histories\n",
    "            refiNs           = refiNs, # refined histories\n",
    "            Ntot             = length(hs) ,\n",
    "            ind_h            = ind_h,\n",
    "            ind_hs           = ind_hs,\n",
    "            truncatedAllocation = truncatedAllocation,\n",
    "            ξs               = ξs)\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bceb8436",
   "metadata": {},
   "outputs": [],
   "source": [
    "function check_truncation(truncatedModel::TruncatedModel,solution::AiyagariSolution,economy::Economy;\n",
    "        noprint=false)::Nothing\n",
    "\n",
    "    @unpack N,Ntot,ind_h,truncatedAllocation,ξs = truncatedModel\n",
    "    @unpack S_h,Π_h,y0_h,ind_y0_h,a_beg_h,a_end_h,c_h,l_h,ly_τ_h,u_h,u′_h,u′′_h,v_h,v′_h,nb_cc_h,ind_cc_h,statDist_h = (\n",
    "                truncatedAllocation)\n",
    "    @unpack β,α,δ,τ,ys,ny,Sy = economy\n",
    "    @unpack w,R,K,L,G,Y,A = solution\n",
    "            \n",
    "    diffs    = sum(S_h) - one(eltype(S_h))\n",
    "    if !(norm(diffs) < 1e-12)\n",
    "        @warn(\"error in S_h. Diff (sum(S_h) - 1): \", round.(diffs,digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: S_h. Diff |sum(S_h) - 1|: \", round.(norm(diffs),digits=4))\n",
    "    end\n",
    "      \n",
    "    diffs    = sum(statDist_h,dims=1)[1,:] - S_h\n",
    "    if !(norm(diffs) < 1e-12)\n",
    "        @warn(\"error in S_h. Diff (∫ₐΛ̃ₕ(da,h) - S_h): \", round.(diffs,digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: S_h. Diff ||∫ₐΛ̃ₕ(da,h) - S_h||∞: \", round.(norm(diffs),digits=4))\n",
    "    end\n",
    "    \n",
    "    share = zeros(eltype(Sy),size(Sy))\n",
    "    for (i,s) in enumerate(S_h)\n",
    "        share[ind_y0_h[i]] += s\n",
    "    end\n",
    "    diffs    = share - Sy\n",
    "    if !(norm(diffs) < 1e-12)\n",
    "        @warn(\"error in S_h. Diff (share - Sy): \", round.(diffs,digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: S_h. Diff ||share - Sy||∞: \", round.(norm(diffs),digits=4))\n",
    "    end\n",
    "    \n",
    "    diffs    = S_h - Π_h'*S_h\n",
    "    if !(norm(diffs) < 1e-10)\n",
    "        @warn(\"error in S_h. Diff (S_h - Π_h'*S_h): \", round.(diffs,digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: S_h. Diff ||S_h - Π_h'*S_h||∞: \", round.(norm(diffs),digits=4))\n",
    "    end\n",
    "    diffs    = S_h.*a_beg_h - Π_h'*(S_h.*a_end_h)\n",
    "    if !(norm(diffs) < 1e-10)\n",
    "        @warn(\"error in a_beg_h. Diff (a_beg_h): \", round.(diffs,digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: a_beg_h. Diff ||a_beg_h||∞: \", round.(norm(diffs),digits=4))\n",
    "    end\n",
    "    diffs    = c_h + a_end_h - R*a_beg_h - w*ly_τ_h\n",
    "    if !(norm(diffs) < 1e-10)\n",
    "        @show diffs\n",
    "        @warn(\"error in budget const. Diff (budget const): \", round.(diffs,digits=4))\n",
    "    else\n",
    "        (!noprint)&&println(\"Passed: budget const. Diff ||budget const||∞: \", round.(norm(diffs),digits=4))\n",
    "    end\n",
    "    return nothing\n",
    "end;"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.10.5",
   "language": "julia",
   "name": "julia-1.10"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
